[PyTorch] Pad V when Q/V head dims differ (MLA) for THD #2629
Greptile Summary
This PR fixes a bug in THD-format MLA (e.g. DeepSeek V3) where FlashAttention 2 was entirely blocked for mismatched Q/K and V head dimensions. It introduces zero-padding of V (and optionally Q/K) up to
Confidence Score: 3/5
The non-FP8 THD MLA path works correctly, but removing the FA2 guard while excluding Float8TensorStorage from padding leaves the Float8 + MLA + FA2 combination unprotected. utils.py drops the blanket guard that blocked FA2 whenever head_dim_qk != head_dim_v, and the padding logic in dot_product_attention.py skips Float8TensorStorage inputs, so it does not fully compensate for the removed guard. If a Float8 MLA configuration reaches FA2, FA2 is called with unpadded, mismatched head dims, causing a crash or incorrect results. The removed guard previously covered this case safely.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["DotProductAttention.forward - MLA head_dim_qk != head_dim_v"] --> B{use_flash_attention?}
B -- No --> C[FusedAttention or Unfused]
B -- Yes --> D{backend == FA2 version?}
D -- No --> E[FA3/FA4 support MLA natively - no padding needed]
D -- Yes --> F{value is Float8TensorStorage?}
F -- Yes --> G["Skip padding - FA2 receives mismatched head dims - potential crash"]
F -- No --> H[_pad_qkv_head_dim - pad V to head_dim_qk]
H --> I[flash_attention with padded Q/K/V]
I --> J{orig_qk_dim > orig_v_dim?}
J -- Yes --> K[_trim_output - slice back to orig_head_dim_v]
J -- No --> L[Return attn_out as-is]
K --> M[Correct output]
L --> M
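The branches of the flowchart can be read as a small decision function. The sketch below is only an illustration: the function name, its boolean parameters, and the returned labels are hypothetical and do not come from Transformer Engine, and the equal-head-dim branch is an assumption added for completeness.

```python
def plan_fa2_mla_padding(use_flash_attention, backend_is_fa2,
                         value_is_float8, head_dim_qk, head_dim_v):
    """Return a label for the path the flowchart takes (hypothetical helper)."""
    if head_dim_qk == head_dim_v:
        # Assumption: outside the flowchart, which only covers mismatched dims.
        return "no_padding_needed"
    if not use_flash_attention:
        return "fused_or_unfused_backend"
    if not backend_is_fa2:
        # FA3/FA4 branch: MLA is supported natively, so no padding.
        return "native_mla_support"
    if value_is_float8:
        # The case flagged in the summary: padding is skipped, dims stay mismatched.
        return "unprotected_mismatch"
    # Padding happens; trimming is only needed when V was the shorter side.
    return "pad_then_trim" if head_dim_qk > head_dim_v else "pad_without_trim"


# Illustrative head dims only; 192/128 is used here as an example mismatch.
plan_fp16 = plan_fa2_mla_padding(True, True, False, 192, 128)
plan_fp8 = plan_fa2_mla_padding(True, True, True, 192, 128)
```

Written this way, the Float8 branch stands out as the only path where FA2 is still selected but no padding is applied.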
Pull request overview
This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD format (t = total tokens, h = heads, d = head dimension). When the value tensor has a smaller head dimension than the query/key tensors, the code zero-pads the value tensor to the Q/K head dimension, runs attention, and then trims the output back to the original V dimension.
Changes:
- Added padding logic for V tensor when head dimensions differ in THD format
- Implemented trimming function to restore correct output dimensions after attention
- Added test case for THD attention with mismatched Q/V head dimensions
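Why pad-then-trim is safe can be checked on a toy case: the extra V columns are all zeros, so the corresponding output columns are zeros too, and slicing them off recovers the unpadded result. The sketch below uses a naive single-head attention on nested Python lists; every name in it is illustrative and none of it is the PR's code.

```python
import math


def attention(q, k, v, scale):
    """Naive single-head attention: softmax(scale * q k^T) v, on nested lists."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w / total * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def pad_last_dim(x, target):
    """Append zeros to each row until it has `target` columns."""
    return [row + [0.0] * (target - len(row)) for row in x]


def trim_last_dim(x, target):
    """Keep only the first `target` columns of each row."""
    return [row[:target] for row in x]


# Q/K head dim 4, V head dim 2 (toy sizes chosen for illustration).
q = [[0.1, 0.2, 0.3, 0.4], [0.5, -0.1, 0.0, 0.2], [-0.3, 0.7, 0.1, 0.0]]
k = [[0.2, 0.1, -0.4, 0.3], [0.0, 0.6, 0.2, -0.1], [0.4, 0.4, 0.1, 0.1]]
v = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
scale = 1.0 / math.sqrt(4)

reference = attention(q, k, v, scale)
padded_out = attention(q, k, pad_last_dim(v, 4), scale)
trimmed = trim_last_dim(padded_out, 2)
```

Here `trimmed` matches `reference`, and the two padded output columns are zero in every row.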
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |
This change should only be required by the FlashAttention backend. The other two backends FusedAttention and UnfusedDPA do support MLA (head_dim_qk != head_dim_v). I'd propose a few changes:
@vcherepanov-nv, could you help push this PR over the finish line? Thanks!
Thank you @cyanguwa, I just cleaned up the PR and followed your requirements. Please let me know what you think, @vcherepanov-nv.
/te-ci pytorch L0
/te-ci pytorch L0
@HollowMan6, could you help fix the failed tests please? Sorry, it's an oversight on my side too. Right now, _pad_value_layer and _trim_output both assume that V has a shorter head_dim than Q/K, but it could happen the other way as well.
// failed tests: "mla_1_0", "mla_1_1"
TransformerEngine/tests/pytorch/attention/test_attention.py
Lines 568 to 569 in 5535b09
// failed error:
https://github.com/Dao-AILab/flash-attention/blob/d80a77103021c4e980f8cbbf85774f6a19e6474a/csrc/flash_attn/flash_api.cpp#L418
I wonder if we can make the pad function look something like this:
def _pad_qkv_head_dim(query_layer, key_layer, value_layer):
    return new_q, new_k, new_v, orig_head_dim_qk, orig_head_dim_v
Also, only call _trim_output when padded_head_dim_v > orig_head_dim_v; otherwise it should be a no-op.
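A symmetric pad helper along these lines might look like the sketch below. It is a toy on nested lists rather than tensors, the `_sketch` names are hypothetical, and it assumes the softmax scale is computed from the original Q/K head dim and passed explicitly, since padding Q/K changes the head dim a backend would otherwise derive the scale from.

```python
import math


def attention(q, k, v, scale):
    """Naive single-head attention on nested lists, with an explicit scale."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w / total * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def _pad_rows(x, target):
    return [row + [0.0] * (target - len(row)) for row in x]


def pad_qkv_head_dim_sketch(q, k, v):
    """Zero-pad whichever side is shorter so Q/K and V share one head dim."""
    orig_qk, orig_v = len(q[0]), len(v[0])
    target = max(orig_qk, orig_v)
    return (_pad_rows(q, target), _pad_rows(k, target), _pad_rows(v, target),
            orig_qk, orig_v)


def trim_output_sketch(out, orig_v):
    """Slice back to the original V head dim; a no-op if V was not padded."""
    if len(out[0]) > orig_v:
        return [row[:orig_v] for row in out]
    return out


# The "other way" case: V head dim (4) larger than Q/K head dim (2).
q = [[0.3, -0.2], [0.1, 0.4]]
k = [[0.5, 0.1], [-0.2, 0.3]]
v = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.5, -0.5, 0.0]]
scale = 1.0 / math.sqrt(2)  # from the ORIGINAL Q/K head dim, not the padded one

reference = attention(q, k, v, scale)
new_q, new_k, new_v, orig_qk, orig_v = pad_qkv_head_dim_sketch(q, k, v)
result = trim_output_sketch(attention(new_q, new_k, new_v, scale), orig_v)
```

Zero columns appended to Q and K add nothing to the dot products, so the attention weights are unchanged as long as the scale is held fixed; in this case the trim step leaves the output untouched.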
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Thank you for pointing this out, @cyanguwa. Originally I didn't handle the v > qk case because it is not used in practice for MLA, but since the test cases cover it, I have just pushed the changes accordingly.
/te-ci pytorch L0
Description
For MLA in the THD format, pad V when the Q/V head dims differ.
Similar to NVIDIA/Megatron-LM#3003
Fixes NVIDIA/Megatron-LM#1698